import scipy
import numpy as np
import sklearn.decomposition
import logging
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import pandas as pd
import matplotlib
import itertools

class PCAExplorer(object):
    """
    Fit a PCA on z-scored tabular data and visualize the result.

    The input is a pandas DataFrame with one identifier column (the group
    label of every row) and the feature columns used for the PCA.

    After a successful construction the instance exposes:
      groupIds   -- list of the distinct group labels (categories of idColumn)
      groups     -- array with the integer group code of every row
      data       -- z-scored feature matrix the PCA was fitted on
      components -- names of the feature columns (all columns except idColumn)
      pca        -- the fitted ``sklearn.decomposition.PCA`` model
      compNames  -- display names of the PCA axes ('PCA1', 'PCA2', ...)

    If construction fails the error is logged and all of the above stay
    ``None``; the plotting methods then raise ``RuntimeError``.
    """
    # Matplotlib colour codes; cycled when there are more groups than colours.
    colors = ['r', 'b', 'y', 'g', 'c', 'm', 'k']

    def __init__(self, trainingData, idColumn='MouseID', n_components=2):
        """
        Standardize ``trainingData`` and fit a PCA on it.

        trainingData -- DataFrame holding ``idColumn`` plus the feature columns
        idColumn     -- name of the column carrying the group label
        n_components -- forwarded unchanged to ``sklearn.decomposition.PCA``

        Failures are logged (with traceback) instead of being raised, so the
        constructor itself never throws.
        """
        # Define every attribute up front so that a failed construction leaves
        # a well-defined object rather than one with missing attributes.
        self.groupIds = None
        self.groups = None
        self.data = None
        self.components = None
        self.pca = None
        self.compNames = None

        try:
            #Standardize the data
            groupIds, groups, zdata = self._importData(trainingData, idColumn)
            components = trainingData.columns.tolist()
            components.remove(idColumn)

            pca = sklearn.decomposition.PCA(n_components=n_components)
            pca.fit(zdata)
            # Use the fitted component count: the requested ``n_components``
            # may legitimately be None or a float, which cannot drive range().
            compNames = ['PCA%d' % n for n in range(1, pca.n_components_ + 1)]
        except Exception as e:
            logging.exception('Generating PCA Explorer failed! %s', e)
            return

        # Commit the state only once every step succeeded.
        self.groupIds, self.groups, self.data = groupIds, groups, zdata
        self.components = components
        self.pca = pca
        self.compNames = compNames

    def _requireModel(self):
        """
        Return the fitted PCA model, raising RuntimeError if there is none.
        """
        pca = getattr(self, 'pca', None)
        if pca is None:
            raise RuntimeError('PCA model is not available; PCAExplorer construction failed')
        return pca

    @staticmethod
    def _zData(data):
        """
        Calculate the z-score on the data and filter NaN (set zeros)

        The z-score is computed column by column.  NaN values in the result
        are replaced through ``numpy.nan_to_num`` and a warning is logged.
        Returns a DataFrame.
        """
        # Imported here because a bare ``import scipy`` is not guaranteed to
        # load the ``scipy.stats`` subpackage on every SciPy release.
        from scipy.stats import mstats

        zdata = data.apply(mstats.zscore)
        if zdata.isnull().any().any():
            logging.warning('Filtering NaN in standardized data!')
            zdata = zdata.apply(np.nan_to_num)

        return zdata

    @staticmethod
    def _getColors(ids):
        """
        Pair every group id with its index and a colour.

        Returns an iterable of ``(index, id, colour)`` tuples; colours repeat
        once ``PCAExplorer.colors`` is exhausted.
        """
        return zip(itertools.count(), ids, itertools.cycle(PCAExplorer.colors))

    @staticmethod
    def _importData(data, idColumn):
        """
        Split a DataFrame into group information and standardized features.

        Returns ``(groupIds, groups, values)``: the list of categories of
        ``idColumn``, the per-row category codes as an array and the z-scored
        feature values (all columns except ``idColumn``) as an array.
        """
        zdata = PCAExplorer._zData(data.drop(idColumn, axis=1))
        # Coerce to categorical so plain string/integer ID columns work too;
        # an already categorical column keeps its categories unchanged.
        ids = data[idColumn].astype('category')
        groupIds = ids.cat.categories.tolist()
        group = ids.cat.codes
        return groupIds, group.values, zdata.values

    @staticmethod
    def _drawPath(ax, coords, name, color):
        """
        Draw consecutive points as line segments that fade in along the path.

        coords is a list with one coordinate array per plot axis (2 or 3).
        Only the last, fully opaque segment carries the label so that the
        group shows up exactly once in the legend.
        """
        nSegments = len(coords[0]) - 1
        for n in range(nSegments):
            segment = [c[n:n + 2] for c in coords]
            isLast = n == nSegments - 1
            ax.plot(*segment,
                    alpha=float(n + 1) / nSegments,
                    color=color,
                    label=name if isLast else '_nolegend_')

    def plotComponents(self, ax=None):
        """
        Generate a correlation plot between PCA components and input components

        Draws ``pca.components_`` as a colour mesh (colour range -1..1) with
        the input column names on the x axis and the PCA axis names on the
        y axis.

        ax -- axes to draw on; a new figure is created (and shown) when None.
        Returns the axes that were drawn on.
        Raises RuntimeError when no fitted PCA model is available.
        """
        pca = self._requireModel()

        fig = None
        if ax is None:
            fig, ax = plt.subplots()

        img = ax.pcolormesh(pca.components_, cmap='RdBu', vmin=-1, vmax=1)

        ax.set_title('PCA components correlation, %2.0f%% Variance coverage' % (sum(pca.explained_variance_ratio_) * 100.0))
        # Tick positions must be fixed before the labels are assigned,
        # otherwise the labels are attached to the automatic tick locations.
        ax.set_xticks(np.arange(len(self.components)) + 0.5, minor=False)
        ax.set_yticks(np.arange(len(self.compNames)) + 0.5, minor=False)
        ax.set_xticklabels(self.components, minor=False)
        ax.set_yticklabels(self.compNames, minor=False)
        plt.setp(ax.xaxis.get_majorticklabels(), rotation=45)
        plt.colorbar(img, ax=ax)

        # Show the figure only after it has been drawn completely.
        if fig is not None:
            fig.show()

        return ax

    def plot(self, data=None, idColumn='MouseID', plottype='scatter'):
        """
        Generate a scatter or path plot in PCA space

        data     -- DataFrame to project; the training data is used when None.
                    NOTE(review): new data is z-scored with its own mean and
                    standard deviation (via _importData), not with the
                    statistics of the training data -- confirm this is wanted.
        idColumn -- name of the group label column in ``data``
        plottype -- 'scatter' or 'path'; any other value logs a warning and
                    falls back to 'scatter'

        Raises NotImplementedError unless the PCA has 2 or 3 components and
        RuntimeError when no fitted PCA model is available.
        """
        pca = self._requireModel()

        # Validate before doing any work or opening a figure.
        dim = pca.n_components_
        if dim not in (2, 3):
            raise NotImplementedError(
                'Cannot visualize %d PCA results; only 2 or 3 components are supported' % dim)

        if plottype not in ('scatter', 'path'):
            logging.warning('Unknown plottype %r, falling back to scatter', plottype)
            plottype = 'scatter'

        if data is None:
            grpIDs, grps, pts = self.groupIds, self.groups, self.data
        else:
            grpIDs, grps, pts = self._importData(data, idColumn=idColumn)

        pts = pca.transform(pts)
        N = len(pts)

        if dim == 3:
            f = plt.figure()
            ax = f.add_subplot(111, projection='3d')
        else:
            f, ax = plt.subplots()

        for i, name, color in self._getColors(grpIDs):
            members = grps == i
            # One coordinate array per plot axis, so 2-D and 3-D share a path.
            coords = [pts[members, d] for d in range(dim)]
            if plottype == 'path':
                self._drawPath(ax, coords, name, color)
            else:
                ax.scatter(*coords, label=name, c=color, alpha=0.5)

        ax.set_xlabel(self.compNames[0])
        ax.set_ylabel(self.compNames[1])
        if dim == 3:
            ax.set_zlabel(self.compNames[2])

        ax.legend()
        ax.set_title('PCA - On Data N = %d, %2.0f%% Variance coverage' % (N, sum(pca.explained_variance_ratio_) * 100.0))
        f.show()

